Self-Attention:
- \(Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}})V\)
- \(X\): context-free embedding of the input sequence
- \(X \in R^{n \times d_{model}}\)
- \(n\): sequence length
- \(d_{model}\): embedding dimension
- \(Q = X W_Q\)
- \(K = X W_K\)
- \(V = X W_V\)
- where \(W_Q, W_K, W_V\) are learnable parameters, called projection matrices
- \(W_Q \in R^{d_{model} \times d_k}\): what to look for
- \(W_K \in R^{d_{model} \times d_k}\): what to compare with
- \(W_V \in R^{d_{model} \times d_v}\): what information to extract
Attention:
- score: \(S = QK^T\)
- scaled: \(S_{scaled} = \frac{S}{\sqrt{d_k}}\)
- softmax: \(A = softmax(S_{scaled})\)
Final output: \(Attention(Q, K, V) = AV\)
- the whole process lifts the context-free embedding into a contextualized representation
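The steps above can be sketched in NumPy (the shapes and random weights are illustrative, not a trained model):

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V."""
    d_k = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d_k)                  # scaled scores, shape (n, n)
    S = S - S.max(axis=-1, keepdims=True)       # subtract row max for numerical stability
    A = np.exp(S) / np.exp(S).sum(axis=-1, keepdims=True)  # row-wise softmax
    return A @ V                                # shape (n, d_v)

rng = np.random.default_rng(0)
n, d_model, d_k, d_v = 5, 16, 8, 8
X = rng.normal(size=(n, d_model))               # context-free embeddings
W_Q = rng.normal(size=(d_model, d_k))           # what to look for
W_K = rng.normal(size=(d_model, d_k))           # what to compare with
W_V = rng.normal(size=(d_model, d_v))           # what information to extract
out = scaled_dot_product_attention(X @ W_Q, X @ W_K, X @ W_V)
print(out.shape)  # (5, 8): one contextualized vector per position
```

Each output row mixes value vectors from every position, which is exactly the lift from context-free to contextualized representations.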
\(d_k\) and \(d_v\) are hyperparameters
- often \(d_k = d_v = d_{model} / h\)
- \(h\): number of attention heads
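For instance, the original Transformer base configuration uses \(d_{model} = 512\) and \(h = 8\):

```python
d_model, h = 512, 8          # Transformer base configuration
d_k = d_v = d_model // h     # per-head dimension
print(d_k)  # 64
```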
Causal Masking:
- mask = [ [0, -∞, -∞], [0, 0, -∞], [0, 0, 0] ]
- \(A = softmax(S_{scaled} + mask)\)
- In decoder self-attention, we're generating from left to right. During training, we have the full target sequence available, but we don't want position \(i\) to attend to future positions \(j > i\); adding \(-\infty\) before the softmax drives those attention weights to exactly zero.
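A minimal NumPy sketch of the mask (shown here for \(n = 3\), with uniform scores so the effect of the mask is easy to see):

```python
import numpy as np

n = 3
# Upper-triangular -inf mask: position i may only attend to positions j <= i.
mask = np.triu(np.full((n, n), -np.inf), k=1)

S_scaled = np.zeros((n, n))                 # uniform scores, purely for illustration
S = S_scaled + mask
S = S - S.max(axis=-1, keepdims=True)
A = np.exp(S) / np.exp(S).sum(axis=-1, keepdims=True)
# exp(-inf) = 0, so row 0 attends only to itself,
# row 1 splits attention over positions 0-1, row 2 over all three.
print(A)
```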
Scalability:
- Time complexity: \(O(n^2 d_{model})\)
- Space complexity: \(O(n^2)\)
- for storing the attention matrix (\(S_{scaled}\))
- Quadratic scaling in context length is why transformers struggle with long contexts
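A quick back-of-the-envelope calculation makes the quadratic space cost concrete: a single fp16 attention matrix for a 32k-token context already takes 2 GiB, per head, per layer.

```python
n = 32768                       # context length
bytes_per_elem = 2              # fp16
attn_bytes = n * n * bytes_per_elem   # one (n, n) attention matrix
print(attn_bytes / 2**30)       # 2.0 GiB -- per head, per layer
```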
Multi-head:
- \(MultiHead(Q, K, V) = Concat(head_1, head_2, ..., head_h) W^O\)
- \(head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)\)
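The two formulas above can be sketched directly in NumPy: run scaled dot-product attention once per head, concatenate, and project with \(W^O\) (random weights, illustrative shapes):

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)   # numerical stability
    return np.exp(x) / np.exp(x).sum(axis=-1, keepdims=True)

def multi_head_attention(X, W_Q, W_K, W_V, W_O):
    """W_Q, W_K, W_V stack the per-head matrices: shape (h, d_model, d_k);
    W_O has shape (h * d_v, d_model)."""
    heads = []
    for WQ, WK, WV in zip(W_Q, W_K, W_V):
        Q, K, V = X @ WQ, X @ WK, X @ WV
        A = softmax(Q @ K.T / np.sqrt(WQ.shape[-1]))
        heads.append(A @ V)                           # (n, d_v) per head
    return np.concatenate(heads, axis=-1) @ W_O       # (n, h*d_v) @ (h*d_v, d_model)

rng = np.random.default_rng(0)
n, d_model, h = 4, 16, 2
d_k = d_v = d_model // h
X = rng.normal(size=(n, d_model))
W_Q = rng.normal(size=(h, d_model, d_k))
W_K = rng.normal(size=(h, d_model, d_k))
W_V = rng.normal(size=(h, d_model, d_v))
W_O = rng.normal(size=(h * d_v, d_model))
out = multi_head_attention(X, W_Q, W_K, W_V, W_O)
print(out.shape)  # (4, 16): back to (n, d_model)
```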
Parameter calculation:
- \(W_i^Q, W_i^K \in R^{d_{model} \times d_k}\) for each head \(i\) (so \(h\) copies of each)
- \(W_i^V \in R^{d_{model} \times d_v}\) for each head \(i\)
- \(W^O \in R^{h d_v \times d_{model}}\)
- with \(d_k = d_v = d_{model}/h\): \(3 h \cdot d_{model} \cdot d_k = 3 d_{model}^2\) for the projections plus \(h d_v \cdot d_{model} = d_{model}^2\) for \(W^O\), i.e. \(4 d_{model}^2\) parameters per attention layer (ignoring biases)
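The count can be verified numerically, assuming \(d_k = d_v = d_{model}/h\) and no bias terms:

```python
d_model, h = 512, 8
d_k = d_v = d_model // h
per_head = d_model * d_k + d_model * d_k + d_model * d_v  # W_i^Q, W_i^K, W_i^V
w_o = (h * d_v) * d_model                                 # output projection W^O
total = h * per_head + w_o
print(total)  # 1048576 == 4 * d_model**2
```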